package com.kxmall.market.core.ratelimit.aspect;

import cn.hutool.core.util.StrUtil;
import com.kxmall.market.core.exception.AppServiceException;
import com.kxmall.market.core.exception.ExceptionDefinition;
import com.kxmall.market.core.ratelimit.annotation.RateLimiter;
import com.kxmall.market.core.util.IpUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;

import java.lang.reflect.Method;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * <p>
 * Rate-limiting aspect.
 * </p>
 * <p>
 * Intercepts every method annotated with {@link RateLimiter}. Before the method runs, it
 * records the call in Redis through {@code limitRedisScript}. When the script reports that the
 * configured maximum was reached within the configured window, the call is rejected with
 * {@link AppServiceException} ({@link ExceptionDefinition#OPERATION_TOO_OFTEN}) and the target
 * method is not invoked.
 * </p>
 *
 * @author fy
 * @date Created in 2020/3/7 10:31
 */
@Slf4j
@Aspect
@Component
@RequiredArgsConstructor(onConstructor_ = @Autowired)
public class RateLimiterAspect {

    private static final String SEPARATOR = ":";
    private static final String REDIS_LIMIT_KEY_PREFIX = "limit:";
    private final StringRedisTemplate stringRedisTemplate;
    private final RedisScript<Long> limitRedisScript;

    /**
     * Pointcut matching every method annotated with {@link RateLimiter}.
     */
    @Pointcut("@annotation(com.kxmall.market.core.ratelimit.annotation.RateLimiter)")
    public void rateLimit() {

    }

    /**
     * Applies the rate limit to the intercepted call and then proceeds with it.
     *
     * @param point the intercepted join point
     * @return whatever the target method returns
     * @throws AppServiceException if the caller exceeded the configured limit
     * @throws Throwable           anything thrown by the target method
     */
    @Around("rateLimit()")
    public Object pointcut(ProceedingJoinPoint point) throws Throwable {
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        // Look up the @RateLimiter annotation through Spring's AnnotationUtils
        RateLimiter rateLimiter = AnnotationUtils.findAnnotation(method, RateLimiter.class);
        if (rateLimiter != null) {
            // The final limit key is the base key plus the client IP address
            String key = resolveBaseKey(rateLimiter, method, point.getArgs()) + SEPARATOR + IpUtil.getIpAddr();
            if (shouldLimited(key, rateLimiter.max(), rateLimiter.timeout(), rateLimiter.timeUnit())) {
                throw new AppServiceException(ExceptionDefinition.OPERATION_TOO_OFTEN);
            }
        }
        return point.proceed();
    }

    /**
     * Resolves the base limit key; the caller appends the client IP.
     * <p>
     * An explicit, non-blank {@link RateLimiter#key()} is used as-is. Otherwise the key is
     * {@code fully.qualified.ClassName.methodName{arg1&arg2...}}. Including the arguments gives
     * calls with different arguments separate buckets, for example several users behind one
     * LAN address. Null arguments are skipped, so argument lists that differ only in where their
     * nulls are end up sharing a key.
     * </p>
     */
    private static String resolveBaseKey(RateLimiter rateLimiter, Method method, Object[] args) {
        String customKey = rateLimiter.key();
        if (!StrUtil.isBlank(customKey)) {
            return customKey;
        }
        String methodId = method.getDeclaringClass().getName() + StrUtil.DOT + method.getName();
        List<String> params = argsToStrList(args);
        String joinedParams = params == null ? "" : String.join("&", params);
        // NOTE(review): if this runs against Redis Cluster, the "{...}" part acts as a hash tag,
        // so only the joined arguments decide the slot. That is harmless for a single-key script,
        // but keep it in mind before changing the key format.
        // TODO consider including the user ID in the key
        return String.format("%s{%s}", methodId, joinedParams);
    }

    /**
     * Records one call for {@code key} in Redis and reports whether it must be rejected.
     *
     * @param key      limit key without the Redis prefix (base key + ":" + IP)
     * @param max      maximum number of calls allowed within the window
     * @param timeout  window length, expressed in {@code timeUnit}
     * @param timeUnit unit of {@code timeout}
     * @return {@code true} when the script returned 0 (limit reached); {@code false} otherwise,
     *         including when the script returned {@code null} (fail open)
     */
    private boolean shouldLimited(String key, long max, long timeout, TimeUnit timeUnit) {
        // Final Redis key format:
        // limit:<custom key>:<IP>
        // limit:<ClassName.methodName{args}>:<IP>
        String redisKey = REDIS_LIMIT_KEY_PREFIX + key;
        // All time values are passed to the script in milliseconds
        long ttl = timeUnit.toMillis(timeout);
        long now = Instant.now().toEpochMilli();
        long expired = now - ttl;
        // StringRedisTemplate serializes script arguments as Strings, so every argument must be a
        // String; passing a Long fails with "java.lang.Long cannot be cast to java.lang.String".
        Long executeTimes = stringRedisTemplate.execute(limitRedisScript, Collections.singletonList(redisKey),
                String.valueOf(now), String.valueOf(ttl), String.valueOf(expired), String.valueOf(max));
        if (executeTimes == null) {
            // No result from Redis: deliberately let the call through rather than block it
            return false;
        }
        if (executeTimes == 0L) {
            // Hitting the limit is an expected client condition, not a server error
            log.warn("【{}】在单位时间 {} 毫秒内已达到访问上限，当前接口上限 {}", redisKey, ttl, max);
            return true;
        }
        // Per-call accounting; debug level because the key embeds raw arguments and the client IP
        log.debug("【{}】在单位时间 {} 毫秒内访问 {} 次", redisKey, ttl, executeTimes);
        return false;
    }

    /**
     * Converts method arguments to their {@link String#valueOf(Object)} form, skipping nulls.
     *
     * @param args the arguments; may be {@code null} or empty
     * @return the string forms in argument order, or {@code null} when there is no non-null
     *         argument (kept for backward compatibility with existing callers)
     */
    public static List<String> argsToStrList(Object... args) {
        if (args == null || args.length == 0) {
            return null;
        }
        List<String> collectParams = new ArrayList<>(args.length);
        for (Object param : args) {
            if (param != null) {
                collectParams.add(String.valueOf(param));
            }
        }
        return collectParams.isEmpty() ? null : collectParams;
    }


}
